--[[
	experiments for nips 2017 paper @ lognormal
]]--

-- Experiment configuration for the chain simulated by Step() and solved by
-- ValueIteration() below. The whole table is dumped to the params log file.
local params = {
	-- Number of states in the chain; the last state is the goal state.
	numStates = 20,
	-- Number of actions per state (Step() only produces a transition for action 1).
	numActions = 1,
	
	-- Probability that one step advances to the next state; copied into every
	-- entry of P as long as usePTable is false.
	pForward = 0.1,
	usePTable = false,
	
	-- Probability of selecting branch 1 when advancing out of state 1.
	pBranch = 1,   -- if pBranch == 1 then it's single-path (never go to path2)
	
	-- Goal rewards for branch 1 / branch 2, handed out at the last state.
	rGoal1 = 1e8,	-- note: scaled 100:1 to make the VI error small, i.e. rGoal=100 <-> reward is 1
	rGoal2 = -1e8;	-- UPDATE: in hindsight the VI threshold should have been changed directly instead
	--rStep = -50,
	goalType = 1,
	-- goalType == 1  <-->  G = 1, aka Type-I chain (the last state loops onto itself)
	-- goalType == 0  <-->  G = 1/inf, aka Type-IV chain (the last state jumps back to state 1)
	
	-- Discount factor used by value iteration.
	gamma = 0.9,
	
	-- Number of independent Run() calls per sample size.
	numRuns = 1000,
}

-- P[i] is the probability that a step taken in state i advances to state i+1.
-- Entries exist for every state except the last one.
-- NOTE(review): the usePTable path is only a placeholder; it prints a reminder
-- and leaves P[i] unset.
local P = {}
do
	local idx = 1
	local lastTransient = params.numStates - 1
	while idx <= lastTransient do
		if params.usePTable then
			print("use p table here")
		else
			P[idx] = params.pForward
		end
		idx = idx + 1
	end
end

-- Sample budgets to sweep over. Each entry applies to one whole batch of runs
-- (the list is NOT indexed by state); inside a run that many transitions are
-- drawn for every state-action pair.
-- Earlier sweep, kept for reference: {10, 20, 30, 40, 50, 60, 70, 80}
local numSamples = {}
do
	local setCount = 1
	local idx = 1
	while idx <= setCount do
		numSamples[idx] = 200 + idx * 20
		idx = idx + 1
	end
end

-- Sample budget of the batch currently being executed (set by the driver loop).
local numSample

-- Branch the simulated chain is currently on (1 or 2); updated by Step().
local stateBranch = 1

--- Simulates a single transition of the chain.
-- Reads the file-level `params`, `P` and `stateBranch`, and may update
-- `stateBranch` when the chain advances out of state 1.
-- @param state   current state index
-- @param action  action index; only action 1 produces a transition
-- @return next state and reward, or nothing at all for any other action
local function Step(state, action)
	if action ~= 1 then
		return
	end

	-- Last state: either stay there (goalType == 1) or restart at state 1,
	-- and pay the goal reward of whichever branch is active.
	if state == params.numStates then
		local nextState = params.goalType == 1 and state or 1
		local reward = stateBranch == 1 and params.rGoal1 or params.rGoal2
		return nextState, reward
	end

	-- With probability 1 - P[state] the chain does not move.
	if math.random() > P[state] then
		return state, 0
	end

	-- Advancing out of state 1 decides which branch the chain is on.
	if state == 1 then
		local onFirstBranch = math.random() <= params.pBranch
		stateBranch = onFirstBranch and 1 or 2
	end
	return state + 1, 0
end

--- Runs in-place value iteration on a tabular model until the largest change
-- of any Q-value within one sweep drops below `tol`.
-- State and action counts and the discount factor are read from the
-- file-level `params` (numStates, numActions, gamma).
-- @param T         T[s][a][s2]: transition probability from s to s2 under a
-- @param R         R[s][a][s2]: reward for that transition
-- @param V         V[s]: state values; read and overwritten in place
-- @param Q         Q[s][a]: action values; read and overwritten in place
-- @param adjTable  adjTable[s][a]: list of successor states with its length
--                  stored in field `n`; pairs with n == 0 are skipped
-- @param tol       optional convergence threshold (default 1e-10)
-- @param maxIter   optional cap on the number of sweeps (default: unbounded)
-- @return number of sweeps performed, and whether the threshold was reached
local function ValueIteration(T, R, V, Q, adjTable, tol, maxIter)
	tol = tol or 1e-10
	maxIter = maxIter or math.huge

	local counter = 0
	local converged = false
	while counter < maxIter do
		counter = counter + 1
		local delta = 0
		for s = 1, params.numStates do
			-- -math.huge (rather than a large finite sentinel) so that any
			-- finite backup, however negative, still updates V[s].
			local maxQ = -math.huge
			for a = 1, params.numActions do
				local successors = adjTable[s][a]
				if successors.n > 0 then
					local sum = 0
					for i = 1, successors.n do
						local s2 = successors[i]
						sum = sum + T[s][a][s2]*(
								R[s][a][s2] +
								params.gamma*V[s2]
							)
					end

					delta = math.max(delta, math.abs(Q[s][a]-sum))
					Q[s][a] = sum
					maxQ = math.max(maxQ, sum)
				end
			end
			-- States without any sampled action keep their previous value.
			if maxQ > -math.huge then
				V[s] = maxQ
			end
		end
		if delta < tol then
			converged = true
			break
		end
	end
	--io.write("<VI> ", counter, " iterations\n")
	return counter, converged
end

--- Performs one experiment run.
-- Draws `numSample` transitions from Step() for every state-action pair,
-- builds the empirical transition and reward tables from them, solves the
-- resulting model with ValueIteration(), and appends Q[1][1] as one line to
-- `fOut` when that handle is set.
local function Run()
	local nStates, nActions = params.numStates, params.numActions

	-- Fresh, zero-initialised tables for this run.
	local V, Q, T, R = {}, {}, {}, {}
	local adjTable, Nsa, Nsas = {}, {}, {}
	for s = 1, nStates do
		local qRow, tRow, rRow, adjRow, countRow = {}, {}, {}, {}, {}
		for a = 1, nActions do
			qRow[a] = 0
			tRow[a] = {}
			rRow[a] = {}
			countRow[a] = {}
			adjRow[a] = {n = 0}
		end
		V[s] = 0
		Q[s], T[s], R[s] = qRow, tRow, rRow
		adjTable[s], Nsas[s] = adjRow, countRow
		Nsa[s] = {}
	end

	for s = 1, nStates do
		for a = 1, nActions do
			Nsa[s][a] = numSample

			local counts = Nsas[s][a]
			local successors = adjTable[s][a]
			local rewards = R[s][a]

			-- Count how often each successor is observed; the reward of a
			-- successor is recorded the first time it shows up.
			for _ = 1, numSample do
				local s2, r = Step(s, a)
				local seen = counts[s2]
				if seen then
					counts[s2] = seen + 1
				else
					counts[s2] = 1
					successors.n = successors.n + 1
					successors[successors.n] = s2
					rewards[s2] = r
				end
			end

			-- Relative frequencies become the transition estimates.
			local probs = T[s][a]
			for k = 1, successors.n do
				local s2 = successors[k]
				probs[s2] = counts[s2] / Nsa[s][a]
			end
		end
	end

	ValueIteration(T, R, V, Q, adjTable)

	if fOut then
		local startValue = Q[1][1]
		fOut:write(startValue)
		fOut:write('\n')
	end
end

-- Seed the generator from the wall clock and throw away the first few draws.
math.randomseed(os.time())
for _ = 1, 4 do
	math.random()
end

-- All output files share a timestamped prefix inside "logsNew/".
local pathPrefix = os.date("logsNew/%y%m%d_%H%M%S")
-- IMPORTANT: this script does not create the "logsNew" folder; create it by hand.
-- The assert() calls below surface the OS error message if it is missing.

-- Dump the configuration, sorted by key so the log is identical between runs
-- (pairs() gives no ordering guarantee).
local fLog = assert(io.open(pathPrefix..".params.txt", "w"))
local paramNames = {}
for k in pairs(params) do
	paramNames[#paramNames + 1] = k
end
table.sort(paramNames)
for _, k in ipairs(paramNames) do
	fLog:write(k .. " = " .. tostring(params[k]) .. "\n")
end
fLog:close()

-- One output file per sample budget, one line per run.
for _, v in ipairs(numSamples) do
	numSample = v
	-- fOut is intentionally a global: Run() looks it up by that name.
	fOut = assert(io.open(pathPrefix.."_"..numSample..".v.txt", "w"))
	for i = 1, params.numRuns do
		Run()
	end
	fOut:close()
end
